import itertools
from copy import deepcopy
from collections import defaultdict
from torch.utils.data import random_split, Subset, Dataset
from collections import Counter
import logging
from sympy.stats import Bernoulli, sample_iter
from random import randint
from torchvision.datasets import CIFAR10
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Subset
from copy import deepcopy
import random


log = logging.getLogger(__name__)

def get_targets(dataset):
    """
    Returns the class labels of a dataset as a flat NumPy array.

    For a ``Subset`` the labels of the wrapped dataset are returned in full;
    they are NOT restricted to ``dataset.indices``. Callers index the result
    with the subset indices themselves.

    Parameters:
    - dataset: A dataset, or a torch.utils.data.Subset wrapping one, that exposes
      its labels through a ``targets`` attribute or, failing that, ``labels``.

    Returns:
    - numpy.ndarray: One-dimensional copy of the labels.

    Raises:
    - AttributeError: If neither ``targets`` nor ``labels`` exists.
    """
    source = dataset.dataset if isinstance(dataset, Subset) else dataset
    targets = source.targets if hasattr(source, 'targets') else source.labels
    if isinstance(targets, torch.Tensor):
        # np.array() on a tensor fails when it requires grad or is not on the
        # CPU, so bring it to host memory explicitly first.
        targets = targets.detach().cpu().numpy()
    # flatten() always copies, so the caller never aliases the dataset's labels.
    return np.asarray(targets).flatten()



def pathological_client_data_split_medmnist(dataset, num_clients, num_classes_per_client=3, min_class_samples=400, seed=None, variable_dataset_size=True):
    """
    Distributes dataset indices among clients in a pathological non-IID manner,
    shrinking the per-class quota when a class is too small to serve every
    client it was assigned to.

    The classes are shuffled and split across the clients; every client is then
    topped up with randomly drawn classes until it holds at least
    ``num_classes_per_client`` of them. For each of its classes a client draws
    the (possibly reduced) quota of samples without replacement, and optionally
    a random number of additional samples from what is left of that class.

    Parameters:
    - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which indices are distributed.
    - num_clients (int): The number of clients among which the dataset is split.
    - num_classes_per_client (int): The minimum number of distinct classes per client.
    - min_class_samples (int): Samples of each assigned class given to a client. If a class
      cannot supply this to all of its clients, the quota for that class becomes
      ``class_size // number_of_assigned_clients``.
    - seed (int, optional): Seed for the NumPy, ``random`` and torch generators.
    - variable_dataset_size (bool): If True, clients may receive additional random samples.

    Returns:
    - list of torch.utils.data.Subset: One subset per client. For a Subset input each
      result is a deep copy of the input with its ``indices`` replaced.

    Raises:
    - ValueError: If ``num_clients`` is smaller than 1, or if
      ``num_classes_per_client`` exceeds the number of classes in the data.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    k = num_classes_per_client
    print(f"num_classes_per_client: {num_classes_per_client}")

    indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
    targets = get_targets(dataset)

    df = pd.DataFrame({"target": targets[indices]}, index=indices)
    label_to_indices = {label: group.index.tolist() for label, group in df.groupby('target')}

    num_samples_per_class = {cls: len(idxs) for cls, idxs in label_to_indices.items()}
    print(f"label_to_indices: {num_samples_per_class}")

    classes = np.array(list(label_to_indices.keys()))
    if k > len(classes):
        # The top-up loop below could never reach k distinct classes and would spin forever.
        raise ValueError(
            f"num_classes_per_client ({k}) exceeds the number of classes in the data ({len(classes)})"
        )
    np.random.shuffle(classes)
    initial_classes = np.array_split(classes, num_clients)

    client_classes = {i: set(initial_classes[i]) for i in range(num_clients)}
    for owned in client_classes.values():
        while len(owned) < k:
            owned.add(np.random.choice(classes))

    # Shrink the quota of any class that cannot give min_class_samples to every client holding it.
    adjusted_min_samples_per_class = {}
    for cls in classes:
        num_holders = sum(1 for owned in client_classes.values() if cls in owned)
        if num_holders * min_class_samples > num_samples_per_class[cls]:
            adjusted_min_samples_per_class[cls] = num_samples_per_class[cls] // num_holders
            print(f"Class {cls} has limited samples. Adjusted min_class_samples: {adjusted_min_samples_per_class[cls]}")
        else:
            adjusted_min_samples_per_class[cls] = min_class_samples

    client_indices = {i: [] for i in range(num_clients)}
    for client, owned in client_classes.items():
        print(f"Client {client} is assigned classes: {sorted(owned)}")
        for cls in owned:
            available_indices = label_to_indices[cls]
            adjusted_min_samples = adjusted_min_samples_per_class[cls]
            chosen_indices = np.random.choice(available_indices, adjusted_min_samples, replace=False)
            client_indices[client].extend(chosen_indices)
            # Remove what was handed out so no sample is given to two clients.
            label_to_indices[cls] = list(set(available_indices) - set(chosen_indices))
            print(f"Client {client} received {len(chosen_indices)} samples from class {cls} (adjusted min_class_samples: {adjusted_min_samples})")

    if variable_dataset_size:
        print(">> Adding variability in dataset sizes ...")
        for client in client_indices:
            for cls in client_classes[client]:
                available_indices = label_to_indices[cls]
                if available_indices:
                    # randint's upper bound is exclusive, so at least one sample of the class stays unassigned.
                    extra_samples = np.random.randint(0, len(available_indices))
                    extra_indices = np.random.choice(available_indices, extra_samples, replace=False)
                    client_indices[client].extend(extra_indices)
                    label_to_indices[cls] = list(set(available_indices) - set(extra_indices))
                    print(f"Client {client} received an additional {len(extra_indices)} samples from class {cls}")

    for client_idx in range(num_clients):
        print(f"Client {client_idx} dataset size: {len(client_indices[client_idx])}")

    client_datasets = []
    for client_idx in range(num_clients):
        client_dataset_indices = client_indices[client_idx]
        if client_dataset_indices:
            print(f"Client {client_idx} indices range: {min(client_dataset_indices)} to {max(client_dataset_indices)}")
        else:
            # A quota of zero and no extra samples leaves a client empty; min()/max() would raise.
            print(f"Client {client_idx} received no samples")
        if isinstance(dataset, Subset):
            subset = deepcopy(dataset)
            subset.indices = client_dataset_indices
        else:
            subset = Subset(dataset=dataset, indices=client_dataset_indices)
        client_datasets.append(subset)

    return client_datasets



def pathological_client_data_split(dataset, num_clients, num_classes_per_client=3, min_class_samples=400, seed=None, variable_dataset_size=True):
    """
    Distributes dataset indices among multiple clients with a fixed number of
    samples per assigned class, simulating a pathological non-IID distribution.

    The classes are shuffled and split across the clients; every client is then
    topped up with randomly drawn classes until it holds at least
    ``num_classes_per_client`` of them. For each of its classes a client draws
    ``min_class_samples`` samples without replacement, and optionally a random
    number of additional samples from what is left of that class.

    Parameters:
    - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which indices are distributed.
    - num_clients (int): The number of clients among which the dataset is split.
    - num_classes_per_client (int): The minimum number of distinct classes per client.
    - min_class_samples (int): The number of samples of each assigned class given to a client.
    - seed (int, optional): Seed for the NumPy, ``random`` and torch generators.
    - variable_dataset_size (bool): If True, clients may receive additional random samples.

    Returns:
    - list of torch.utils.data.Subset: One subset per client. For a Subset input each
      result is a deep copy of the input with its ``indices`` replaced.

    Raises:
    - ValueError: If ``num_clients`` is smaller than 1, if ``num_classes_per_client``
      exceeds the number of classes in the data, or if a class runs out of
      samples before every client holding it received ``min_class_samples``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    k = num_classes_per_client
    print(f"num_classes_per_client: {num_classes_per_client}")

    indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
    targets = get_targets(dataset)

    df = pd.DataFrame({"target": targets[indices]}, index=indices)
    label_to_indices = {label: group.index.tolist() for label, group in df.groupby('target')}

    num_samples_per_class = [len(idxs) for idxs in label_to_indices.values()]
    print(f"label_to_indices: {num_samples_per_class}")

    classes = np.array(list(label_to_indices.keys()))
    if k > len(classes):
        # The top-up loop below could never reach k distinct classes and would spin forever.
        raise ValueError(
            f"num_classes_per_client ({k}) exceeds the number of classes in the data ({len(classes)})"
        )
    np.random.shuffle(classes)
    initial_classes = np.array_split(classes, num_clients)

    client_classes = {i: set(initial_classes[i]) for i in range(num_clients)}
    for owned in client_classes.values():
        while len(owned) < k:
            owned.add(np.random.choice(classes))

    client_indices = {i: [] for i in range(num_clients)}
    for client, owned in client_classes.items():
        print(f"Client {client} is assigned classes: {sorted(owned)}")
        for cls in owned:
            available_indices = label_to_indices[cls]
            if len(available_indices) < min_class_samples:
                print(f"class {cls}, len(available_indices): {len(available_indices)}")
                raise ValueError(f"Not enough samples for class {cls} to meet min_class_samples requirement")
            chosen_indices = np.random.choice(available_indices, min_class_samples, replace=False)
            client_indices[client].extend(chosen_indices)
            # Remove what was handed out so no sample is given to two clients.
            label_to_indices[cls] = list(set(available_indices) - set(chosen_indices))
            print(f"Client {client} received {len(chosen_indices)} samples from class {cls}")

    if variable_dataset_size:
        print(">> Adding variability in dataset sizes ...")
        for client in client_indices:
            for cls in client_classes[client]:
                available_indices = label_to_indices[cls]
                if available_indices:
                    # randint's upper bound is exclusive, so at least one sample of the class stays unassigned.
                    extra_samples = np.random.randint(0, len(available_indices))
                    extra_indices = np.random.choice(available_indices, extra_samples, replace=False)
                    client_indices[client].extend(extra_indices)
                    label_to_indices[cls] = list(set(available_indices) - set(extra_indices))
                    print(f"Client {client} received an additional {len(extra_indices)} samples from class {cls}")

    for client_idx in range(num_clients):
        print(f"Client {client_idx} dataset size: {len(client_indices[client_idx])}")

    client_datasets = []
    for client_idx in range(num_clients):
        client_dataset_indices = client_indices[client_idx]
        if client_dataset_indices:
            print(f"Client {client_idx} indices range: {min(client_dataset_indices)} to {max(client_dataset_indices)}")
        else:
            # A client without classes or with a zero quota is empty; min()/max() would raise.
            print(f"Client {client_idx} received no samples")
        if isinstance(dataset, Subset):
            subset = deepcopy(dataset)
            subset.indices = client_dataset_indices
        else:
            subset = Subset(dataset=dataset, indices=client_dataset_indices)
        client_datasets.append(subset)

    return client_datasets




def pathological_client_data_split_v2(dataset, num_clients, num_classes_per_client=20, min_class_samples=200, seed=None, variable_dataset_size=True):
    """
    Distributes dataset indices among clients by cutting every class into two
    portions of random size and dealing the shuffled portions out to the clients.

    Each class is shuffled and cut at a random point chosen so that both
    portions hold at least ``min_class_samples`` samples. All portions are then
    shuffled together and every client takes the next ``num_classes_per_client``
    of them. A client may receive both portions of one class, and portions left
    over after the last client are not assigned to anyone.

    Parameters:
    - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which indices are distributed.
    - num_clients (int): The number of clients among which the dataset is split.
    - num_classes_per_client (int): The number of class portions each client receives.
    - min_class_samples (int): The minimum size of each of the two portions of a class.
    - seed (int, optional): Seed for the NumPy, ``random`` and torch generators.
    - variable_dataset_size (bool): Accepted for signature compatibility; not used by this variant.

    Returns:
    - list of torch.utils.data.Subset: One subset per client. For a Subset input each
      result is a deep copy of the input with its ``indices`` replaced.

    Raises:
    - ValueError: If a class has fewer than ``2 * min_class_samples`` samples, or if
      the clients together need more portions than the two per class available.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    k = num_classes_per_client
    indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
    targets = get_targets(dataset)

    df = pd.DataFrame({"target": targets[indices]}, index=indices)
    label_to_indices = {label: group.index.tolist() for label, group in df.groupby('target')}

    portions_needed = num_clients * k
    portions_available = 2 * len(label_to_indices)
    if portions_needed > portions_available:
        # Fail early with a clear message instead of running out of portions mid-assignment.
        raise ValueError(
            f"{num_clients} clients x {k} portions each need {portions_needed} portions, "
            f"but only {portions_available} are available (two per class)."
        )

    # Divide each class into two portions with the specified variability in size
    available_portions = []
    for label, indices_list in label_to_indices.items():
        np.random.shuffle(indices_list)
        # Largest first portion that still leaves min_class_samples for the second one.
        max_first_portion_size = len(indices_list) - min_class_samples
        if max_first_portion_size < min_class_samples:
            raise ValueError(f"Not enough samples for class {label} to create portions with the required sizes.")
        first_portion_size = np.random.randint(min_class_samples, max_first_portion_size + 1)
        available_portions.append((label, 0, indices_list[:first_portion_size]))
        available_portions.append((label, 1, indices_list[first_portion_size:]))

    np.random.shuffle(available_portions)
    next_portion = iter(available_portions)
    client_indices = {i: [] for i in range(num_clients)}

    for client in range(num_clients):
        for _ in range(k):
            label, portion_idx, indices_list = next(next_portion)
            client_indices[client].extend(indices_list)
            print(f"Client {client} received {len(indices_list)} samples from class {label}, portion {portion_idx}")

    for client_idx in range(num_clients):
        print(f"Client {client_idx} dataset size: {len(client_indices[client_idx])}")

    client_datasets = []
    for client_idx in range(num_clients):
        client_dataset_indices = client_indices[client_idx]
        if client_dataset_indices:
            print(f"Client {client_idx} indices range: {min(client_dataset_indices)} to {max(client_dataset_indices)}")
        else:
            # Possible when k is 0 or min_class_samples is 0; min()/max() would raise.
            print(f"Client {client_idx} received no samples")
        if isinstance(dataset, Subset):
            subset = deepcopy(dataset)
            subset.indices = client_dataset_indices
        else:
            subset = Subset(dataset=dataset, indices=client_dataset_indices)
        client_datasets.append(subset)

    return client_datasets


def dirichlet_split(dataset, num_clients, seed, alpha_value=0.1, min_size_of_dataset=10):
    """
    Splits a dataset among clients with per-class proportions drawn from a
    symmetric Dirichlet distribution.

    For every class a proportion vector over the clients is sampled and the
    class's indices (in dataset order) are cut into consecutive chunks of those
    proportions. The whole draw is repeated until the smallest client holds at
    least ``min_size_of_dataset`` samples.

    Parameters:
    - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which indices are distributed.
    - num_clients (int): The number of clients among which the dataset is split.
    - seed (int): Seed for the NumPy random generator.
    - alpha_value (float): Concentration of the Dirichlet distribution; smaller values give more skewed splits.
    - min_size_of_dataset (int): The minimum number of samples every client must end up with.

    Returns:
    - list of torch.utils.data.Subset: One subset per client. For a Subset input each
      result is a deep copy of the input with its ``indices`` replaced.

    Raises:
    - ValueError: If ``num_clients`` is smaller than 1, or if the data cannot give
      ``min_size_of_dataset`` samples to every client.
    """
    np.random.seed(seed)
    print(f"alpha = {alpha_value}")

    indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
    targets = get_targets(dataset)

    df = pd.DataFrame({"target": targets[indices]}, index=indices)
    label_to_indices = {label: group.index.to_numpy() for label, group in df.groupby('target')}

    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")
    if len(df) < num_clients * min_size_of_dataset:
        # No draw can ever satisfy the minimum size, so the retry loop would never stop.
        raise ValueError(
            f"Cannot give {min_size_of_dataset} samples to each of {num_clients} clients "
            f"with only {len(df)} samples."
        )

    alpha = np.ones(num_clients) * alpha_value

    # Always draw at least once so the result exists even when
    # min_size_of_dataset is 0; redraw while some client is too small.
    while True:
        client_indices = defaultdict(list)
        for cls_idx in df.target.unique():
            print(f"using vector alpha:{alpha}")

            cls_indices = label_to_indices[cls_idx]
            dist_of_cls_idx_across_clients = np.random.dirichlet(alpha, size=1)[0]
            # Cumulative proportions scaled to the class size give the cut points between clients.
            freq = (np.cumsum(dist_of_cls_idx_across_clients) * len(cls_indices)).astype(int)[:-1]
            for client_idx, client_cls_indices in enumerate(np.split(cls_indices, freq)):
                client_indices[client_idx].extend(client_cls_indices)

        current_min_size = min(len(client_indices[i]) for i in range(num_clients))
        if current_min_size >= min_size_of_dataset:
            break

    assigned_total = sum(len(chunk) for chunk in client_indices.values())
    print(
        f"len(df) {len(df)}, len([idx for _, indices in client_indices.items() for idx in indices]): {assigned_total}")
    print(f">>>> len of each partition: {[len(chunk) for chunk in client_indices.values()]}")
    # Sanity checks: every sample is assigned, and no sample is shared between clients.
    assert len(df) == assigned_total
    client_sets = [set(chunk) for chunk in client_indices.values()]
    assert all(p0.isdisjoint(p1) for p0, p1 in itertools.combinations(client_sets, 2))

    datasets = []
    for client_idx in range(num_clients):
        client_dataset_indices = client_indices[client_idx]
        if isinstance(dataset, Subset):
            subset = deepcopy(dataset)
            subset.indices = client_dataset_indices
        else:
            subset = Subset(dataset=dataset, indices=client_dataset_indices)
        datasets.append(subset)

    return datasets

#
# def pathological_client_data_split(dataset, num_clients, num_classes_per_client=3, min_class_samples=400, seed=None, variable_dataset_size=True):
#     """
#         Distributes dataset indices among multiple clients with a minimum number of samples per class,
#         simulating a pathological non-IID distribution typical in federated learning setups. Each client is assigned
#         a fixed number of unique classes, and optionally, additional samples can be distributed to introduce variability
#         in dataset sizes among clients.
#
#         Parameters:
#         - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which indices are distributed.
#         - num_clients (int): The number of clients among which the dataset is split.
#         - num_classes_per_client (int): The number of unique classes each client will receive.
#         - min_class_samples (int): The minimum number of samples per class to distribute to each client.
#         - seed (int, optional): Seed for random number generators to ensure reproducibility.
#         - variable_dataset_size (bool): If True, clients may receive additional random samples, leading to variable dataset sizes.
#
#         Returns:
#         - list of torch.utils.data.Subset: Each subset corresponds to the dataset of a specific client.
#         """
#     if seed is not None:
#         np.random.seed(seed)
#         random.seed(seed)
#         torch.manual_seed(seed)
#
#     k = num_classes_per_client
#     print(f"num_classes_per_client: {num_classes_per_client}")
#
#     indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
#     targets = dataset.dataset.targets if isinstance(dataset, Subset) else dataset.targets
#     targets = np.array(targets) if isinstance(targets, torch.Tensor) else np.array(targets)
#
#     df = pd.DataFrame({"target": targets[indices]}, index=indices)
#     label_to_indices = {label: group.index.tolist() for label, group in df.groupby('target')}
#
#     classes = np.array(list(label_to_indices.keys()))
#     np.random.shuffle(classes)
#     initial_classes = np.array_split(classes, num_clients)
#
#     client_classes = {i: set(initial_classes[i]) for i in range(num_clients)}
#     for client in client_classes:
#         while len(client_classes[client]) < k:
#             client_classes[client].add(np.random.choice(classes))
#
#     client_indices = {i: [] for i in range(num_clients)}
#     for client in client_classes:
#         print(f"Client {client} is assigned classes: {sorted(client_classes[client])}")
#         for cls in client_classes[client]:
#             available_indices = label_to_indices[cls]
#             if len(available_indices) < min_class_samples:
#                 raise ValueError(f"Not enough samples for class {cls} to meet min_class_samples requirement")
#             chosen_indices = np.random.choice(available_indices, min_class_samples, replace=False)
#             client_indices[client].extend(chosen_indices)
#             label_to_indices[cls] = list(set(available_indices) - set(chosen_indices))
#             print(f"Client {client} received {len(chosen_indices)} samples from class {cls}")
#
#     # Adding variability in dataset sizes
#     if variable_dataset_size:
#         print(">> Adding variability in dataset sizes ...")
#         for client in client_indices:
#             for cls in client_classes[client]:
#                 available_indices = label_to_indices[cls]
#                 if available_indices:
#                     extra_samples = np.random.randint(0, len(available_indices))
#                     extra_indices = np.random.choice(available_indices, extra_samples, replace=False)
#                     client_indices[client].extend(extra_indices)
#                     label_to_indices[cls] = list(set(available_indices) - set(extra_indices))
#                     print(f"Client {client} received an additional {len(extra_indices)} samples from class {cls}")
#
#     # Create client datasets
#     client_datasets = []
#     for client_idx in range(num_clients):
#         client_dataset_indices = client_indices[client_idx]
#         if isinstance(dataset, Subset):
#             subset = deepcopy(dataset)
#             subset.indices = client_dataset_indices
#             client_datasets.append(subset)
#         else:
#             subset = Subset(dataset=dataset, indices=client_dataset_indices)
#             client_datasets.append(subset)
#
#     return client_datasets
#
#
#
# def dirichlet_split(
#         dataset, num_clients, seed, alpha_value=0.05, min_size_of_dataset=10
# ):
#     np.random.seed(seed)
#     print(f"alpha = {alpha_value}")
#
#     indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
#     targets = dataset.dataset.targets if isinstance(dataset, Subset) else dataset.targets
#     targets = targets.numpy() if isinstance(targets, torch.Tensor) else np.array(targets)
#     df = pd.DataFrame({"target": targets[indices]}, index=indices)
#     label_to_indices = {}
#     # Map indices to classes (labels, targets)
#     for label, group in df.groupby('target'):
#         label_to_indices[label] = group.index
#     labels, classes_count_ = np.unique(df.target, return_counts=True)
#     classes_count = defaultdict(int)
#     for label, count in zip(labels, classes_count_):
#         classes_count[label] = count
#     # client_indices = defaultdict(list)
#     # client_indices_per_class = defaultdict(dict)
#     current_min_size = 0
#
#     while current_min_size < min_size_of_dataset:
#         client_indices = defaultdict(list)
#         client_indices_per_class = defaultdict(dict)
#         for cls_idx in df.target.unique():
#             alpha = np.ones(num_clients) * alpha_value
#             # alpha = [0.1,0.5,0.1,0.5,0.1,0.1,0.5,0.1,0.5,0.1] # TODO: just for an experiment
#             print(f"using vector alpha:{alpha}")
#
#             dist_of_cls_idx_across_clients = np.random.dirichlet(alpha, size=1)
#             dist_of_cls_idx_across_clients = dist_of_cls_idx_across_clients[0]
#             freq = (
#                            np.cumsum(dist_of_cls_idx_across_clients) * classes_count[cls_idx]
#                    ).astype(int)[:-1]
#             assign = np.split(label_to_indices[cls_idx], freq)
#             for client_idx, client_cls_indicies in enumerate(assign):
#                 client_indices[client_idx].extend(client_cls_indicies)
#                 client_indices_per_class[client_idx][cls_idx] = client_cls_indicies
#
#             current_min_size = min([len(client_indices[i]) for i in range(num_clients)])
#     print(f"len(df) {len(df)},len([idx for _, indices in client_indices.items() for idx in indices]): {len([idx for _, indices in client_indices.items() for idx in indices])}")
#     print(f">>>> len of each partition: {[len(indices) for _, indices in client_indices.items()]}")
#     assert len(df) == len([idx for _, indices in client_indices.items() for idx in indices])
#     # assert that there is no intersection between clients indices!
#     assert all((set(p0).isdisjoint(set(p1))) for p0, p1 in
#                itertools.combinations([indices for _, indices in client_indices.items()], 2))
#     datasets = []
#     for client_idx in range(num_clients):
#         indices = client_indices[client_idx]
#         if isinstance(dataset, Subset):
#             subset = deepcopy(dataset)
#             subset.indices = indices
#             datasets.append(subset)
#         else:
#             subset = Subset(dataset=dataset, indices=indices)
#             # subset.__class__.__getattr__ = new_getattr  # trick to get the attrs of the original dataset
#             datasets.append(subset)
#     return datasets


import numpy as np
import torch
import pandas as pd
from torch.utils.data import Subset
from copy import deepcopy
import random


def practical_noniid_client_data_split(dataset, num_clients, data_fraction=0.5, seed=42):
    """
    Distributes dataset indices among multiple clients in a practical non-IID manner.
    Each class of the dataset is partitioned into shards of varying sizes
    (``num_clients - 2`` shards of 1%, one shard of 80% and one shard with the
    remainder) and the shards of each class are distributed randomly among the
    clients, one shard per client. Clients therefore receive data from every
    class as long as the class is large enough for every shard to be non-empty.

    This distribution was inspired by the paper: APPLE
    (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10184830/).
    "The practical non-IID setting is more similar to the real-world FL in medical applications."

    Parameters:
    - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which
      indices are distributed. The (wrapped) dataset must expose a ``targets`` attribute.
    - num_clients (int): The number of clients among which the dataset is split (2 to 22).
    - data_fraction (float): Fraction of the total data to be used. (0.0 < data_fraction <= 1.0)
    - seed (int, optional): Seed for random number generators to ensure reproducibility.

    Returns:
    - list of torch.utils.data.Subset: Each subset corresponds to the dataset of a specific client.

    Raises:
    - ValueError: If ``num_clients`` is outside 2..22 (the shard percentages would not
      add up to 100 or would turn negative) or ``data_fraction`` is outside (0, 1].
    """
    if not 2 <= num_clients <= 22:
        raise ValueError(
            f"num_clients must be between 2 and 22 for the 1%/80%/remainder shard scheme, got {num_clients}"
        )
    if not 0.0 < data_fraction <= 1.0:
        raise ValueError(f"data_fraction must be in (0.0, 1.0], got {data_fraction}")

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    # Get the indices of the samples in use and the labels of the whole underlying dataset
    is_subset = isinstance(dataset, Subset)
    indices = np.array(dataset.indices) if is_subset else np.arange(len(dataset))
    all_targets = np.array(dataset.dataset.targets if is_subset else dataset.targets)

    # Sample a fraction of the total data
    if data_fraction < 1.0:
        num_samples = int(len(indices) * data_fraction)
        indices = np.random.choice(indices, num_samples, replace=False)

    # Always restrict the labels to the indices in use; without this a Subset
    # input with data_fraction == 1.0 pairs indices with the wrong labels.
    df = pd.DataFrame({"target": all_targets[indices]}, index=indices)
    label_to_indices = {label: group.index.tolist() for label, group in df.groupby('target')}

    # Shard layout in percent: (num_clients - 2) x 1%, then 80%, then the remainder.
    shard_percents = [1] * (num_clients - 2) + [80, 22 - num_clients]
    shard_fractions = np.array(shard_percents) / 100.0

    client_indices = {i: [] for i in range(num_clients)}

    for cls, idxs in label_to_indices.items():
        np.random.shuffle(idxs)
        shard_sizes = (shard_fractions * len(idxs)).astype(int)  # Get number of samples per shard

        # Ensure the sum of shard sizes equals the total number of indices
        shard_sizes[-1] += len(idxs) - shard_sizes.sum()

        print(f"Class {cls}: Shard sizes (in samples) - {shard_sizes.tolist()}")

        shards = np.split(idxs, np.cumsum(shard_sizes)[:-1])

        unassigned_shards = list(range(len(shards)))
        np.random.shuffle(unassigned_shards)

        # There is exactly one shard per client; each pick removes its shard
        # from the pool, so every shard is handed out exactly once.
        for client_idx in range(num_clients):
            shard_idx = unassigned_shards.pop(client_idx % len(unassigned_shards))
            client_indices[client_idx].extend(shards[shard_idx])

    client_datasets = []
    for client_idx in range(num_clients):
        client_dataset_indices = client_indices[client_idx]
        if is_subset:
            subset = deepcopy(dataset)
            subset.indices = client_dataset_indices
        else:
            subset = Subset(dataset=dataset, indices=client_dataset_indices)
        client_datasets.append(subset)

        print(f"Client {client_idx} has {len(client_dataset_indices)} samples.")

    return client_datasets


# def practical_noniid_client_data_split(dataset, num_clients, seed=42):
#     """
#     Distributes dataset indices among multiple clients in a practical non-IID manner.
#     Each class of the dataset is partitioned into shards of varying sizes and
#     distributed randomly among the clients such that every client possesses data
#     from every class.
#
#     This distribution was inspired by the paper: APPLE
#     (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10184830/).
#     ”The practical non-IID setting is more similar to the real-world FL in medical applications.”
#
#     Parameters:
#     - dataset (torch.utils.data.Dataset or torch.utils.data.Subset): The dataset or subset from which indices are distributed.
#     - num_clients (int): The number of clients among which the dataset is split.
#     - seed (int, optional): Seed for random number generators to ensure reproducibility.
#
#     Returns:
#     - list of torch.utils.data.Subset: Each subset corresponds to the dataset of a specific client.
#     """
#     if seed is not None:
#         np.random.seed(seed)
#         random.seed(seed)
#         torch.manual_seed(seed)
#
#     indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
#     targets = dataset.dataset.targets if isinstance(dataset, Subset) else dataset.targets
#     targets = np.array(targets) if isinstance(targets, torch.Tensor) else np.array(targets)
#
#     df = pd.DataFrame({"target": targets[indices]}, index=indices)
#     label_to_indices = {label: group.index.tolist() for label, group in df.groupby('target')}
#
#     client_indices = {i: [] for i in range(num_clients)}
#
#     for cls, idxs in label_to_indices.items():
#         np.random.shuffle(idxs)
#         num_1_percent_shards = num_clients - 2
#         remaining_percent = 100 - (80 + num_1_percent_shards)
#
#         shard_sizes = [1] * num_1_percent_shards + [80, remaining_percent]
#         shard_sizes = np.array(shard_sizes) / 100.0  # Convert to fractions
#         shard_sizes = (shard_sizes * len(idxs)).astype(int)  # Get number of samples per shard
#
#         # Ensure the sum of shard sizes equals the total number of indices
#         shard_sizes[-1] += len(idxs) - shard_sizes.sum()
#
#         print(f"Class {cls}: Shard sizes (in samples) - {shard_sizes.tolist()}")
#
#         shards = np.split(idxs, np.cumsum(shard_sizes)[:-1])
#
#         shard_indices = list(range(len(shards)))
#         np.random.shuffle(shard_indices)
#
#         for client_idx in range(num_clients):
#             shard_idx = shard_indices[client_idx % len(shard_indices)]
#             client_indices[client_idx].extend(shards[shard_idx])
#             shard_indices.remove(shard_idx)
#
#     client_datasets = []
#     for client_idx in range(num_clients):
#         client_dataset_indices = client_indices[client_idx]
#         if isinstance(dataset, Subset):
#             subset = deepcopy(dataset)
#             subset.indices = client_dataset_indices
#             client_datasets.append(subset)
#         else:
#             subset = Subset(dataset=dataset, indices=client_dataset_indices)
#             client_datasets.append(subset)
#
#         print(f"Client {client_idx} has {len(client_dataset_indices)} samples.")
#
#     return client_datasets




def spec_nonIID(training_dataset,seed, num_clients=10, classes_per_client=1, own_class_percent=0.91):
    """
    Splits a Subset among clients so that each client owns most of the samples
    of "its" classes while the rest of every class is shared among the others.

    The classes present in the subset are shuffled and divided among the
    clients. For every owned class the owner keeps about ``own_class_percent``
    of the samples and the remainder is split evenly across all other clients.
    Finally every client is shuffled and truncated to the size of the smallest
    client so that all clients hold the same number of samples.

    Parameters:
    - training_dataset (torch.utils.data.Subset): Subset whose wrapped dataset exposes ``targets``.
    - seed (int): Seed for the NumPy random generator (also used by ``train_test_split``
      because no ``random_state`` is passed to it).
    - num_clients (int): The number of clients among which the data is split (at least 2).
    - classes_per_client (int): Currently unused; kept for signature compatibility.
    - own_class_percent (float): Fraction of an owned class's samples kept by its owner.

    Returns:
    - list of torch.utils.data.Subset: One subset of the wrapped dataset per client.

    Raises:
    - ValueError: If ``num_clients`` is smaller than 2.
    """
    if num_clients < 2:
        # The remainder of every class is shared among the *other* clients.
        raise ValueError(f"spec_nonIID needs at least 2 clients, got {num_clients}")

    np.random.seed(seed)

    # Look the labels up once instead of rescanning the dataset for every class.
    subset_indices = np.asarray(training_dataset.indices)
    labels = np.asarray(training_dataset.dataset.targets).reshape(-1)[subset_indices]
    client_indices = {k: [] for k in range(num_clients)}

    # Determine the classes each client will own (taken from the labels that are actually present)
    all_classes = np.unique(labels)
    np.random.shuffle(all_classes)
    owned_classes = np.array_split(all_classes, num_clients)

    for client_id, owned in enumerate(owned_classes):
        other_clients = [c for c in range(num_clients) if c != client_id]
        for class_id in owned:
            class_indices = subset_indices[labels == class_id].tolist()
            own_indices, distribute_indices = train_test_split(class_indices, test_size=1 - own_class_percent)
            client_indices[client_id].extend(own_indices)

            # Share what the owner did not keep evenly among everybody else.
            shares = np.array_split(distribute_indices, num_clients - 1)
            for other_client_id, share in zip(other_clients, shares):
                client_indices[other_client_id].extend(share.tolist())

    # Equalize the number of samples for each client (truncate to the smallest client)
    samples_per_client = min(len(indices) for indices in client_indices.values())
    for client_id in range(num_clients):
        np.random.shuffle(client_indices[client_id])
        client_indices[client_id] = client_indices[client_id][:samples_per_client]

    client_datasets = [Subset(training_dataset.dataset, indices) for indices in client_indices.values()]

    return client_datasets



def quant_dirichlet_split(all_train_indices, num_clients, seed, alpha=0.5, min_samples=50):
    """
    Quantity-skewed split: client sizes follow a Dirichlet draw, contents are a random permutation
    args:
    all_train_indices: sequence of sample indices to distribute
    num_clients: int
    seed: int passed to `np.random.seed`
    alpha: concentration value repeated for every client in the Dirichlet draw
    min_samples: number of samples every client is guaranteed to receive
    Returns:
    List of length=num_clients with one array of indices per client
    """
    np.random.seed(seed)

    total = len(all_train_indices)
    # What is left to hand out once every client has its guaranteed minimum.
    spare = total - num_clients * min_samples
    assert spare > 0, "Not enough data to satisfy the minimum samples for each client."

    concentration = np.repeat(alpha, num_clients)
    shares = np.random.dirichlet(alpha=concentration)

    # Truncate the fractional allocation to whole samples, then add the guaranteed minimum.
    quota = (shares * spare).astype(int) + min_samples

    # Truncation can leave a few samples unassigned; the first clients take one extra each.
    leftover = total - quota.sum()
    quota[:leftover] += 1

    assert quota.sum() == total, "Allocated samples do not sum up to the dataset size."

    shuffled = np.random.permutation(all_train_indices)

    # Consecutive slices of the shuffled indices, one per client.
    offsets = np.concatenate(([0], np.cumsum(quota)))
    return [shuffled[lo:hi] for lo, hi in zip(offsets[:-1], offsets[1:])]


def random_chunk_data_split(  # to support the splits with minimum number of samples per class
        dataset, num_clients, seed, num_classes=10,
        min_size_per_class=400
):
    """
    Split a dataset among clients in per-class chunks of roughly `min_size_per_class` samples
    args:
    dataset: pytorch dataset (or Subset) whose targets support fancy indexing; labels are looked
        up as 0 .. num_classes - 1
    num_clients: int, number of client datasets to build
    seed: int used to seed numpy's and python's global RNGs
    num_classes: int
    min_size_per_class: int, a class with N samples is cut into int(N / min_size_per_class) chunks
    Returns:
    List of Dataset subset object of length=num_clients
    """
    np.random.seed(seed)
    random.seed(seed)
    print("num_classes:", num_classes)

    # A client is only accepted once it holds this many classes. Capped by num_classes so that
    # the rejection loop below can terminate when fewer than 5 classes exist.
    min_classes_per_client = min(5, num_classes)

    # get the binary dist of clients (i.e., what missing classes)
    clients_lists = defaultdict(list)
    for i in range(num_clients):
        while True:
            p = random.random()
            # NOTE(review): whether this draw follows the seeds set above depends on how the
            # installed sympy samples - confirm if reproducible splits are required.
            clients_lists[i] = list(sample_iter(Bernoulli('X', p), numsamples=num_classes))

            if sum(clients_lists[i]) >= min_classes_per_client:
                break
    print(clients_lists)

    # for every class, the list of clients that hold it
    classes_lists = [
        [i for i in range(num_clients) if clients_lists[i][c] == 1]
        for c in range(num_classes)
    ]

    print("classes_lists:", classes_lists)
    indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))

    targets = dataset.dataset.targets if isinstance(dataset, Subset) else dataset.targets
    df = pd.DataFrame({"target": targets[indices]}, index=indices)
    label_to_indices = {label: set(group.index) for label, group in df.groupby('target')}

    clients_splits = defaultdict(list)

    for i in range(len(classes_lists)):
        # np.array_split raises when a class has fewer than min_size_per_class samples (0 chunks)
        num_of_splits = int(len(label_to_indices[i]) / min_size_per_class)
        splits = np.array_split(list(label_to_indices[i]), num_of_splits)
        print("splits:", len(splits))

        used_splits = 0

        for j in (classes_lists[i]):  # give each client a single spilt initially
            print(f"classes_lists[{i}]:", classes_lists[i])
            print(f"clients_splits[{j}].extend(splits{[used_splits]})")
            if (used_splits <= len(splits) - 1):
                clients_splits[j].extend(splits[used_splits])
                used_splits += 1
            else:
                classes_lists[i] = classes_lists[i][0:len(splits)]
                break

            # NOTE(review): this loop sits inside the per-client loop, so all remaining chunks are
            # handed out at random right after the FIRST client received its chunk. The comments
            # suggest it may have been meant to run after every client got one - confirm before
            # changing, since moving it alters the resulting distribution.
            for k in range(len(splits) - used_splits):  # add more chunks randomly
                p = random.choice(classes_lists[i])
                clients_splits[p].extend(splits[used_splits])
                used_splits += 1

    datasets = []
    for client_idx in range(num_clients):
        indices = clients_splits[client_idx]
        if isinstance(dataset, Subset):
            subset = deepcopy(dataset)
            subset.indices = indices
            datasets.append(subset)
        else:
            subset = Subset(dataset=dataset, indices=indices)
            datasets.append(subset)
    log.info("Number of data points per client")
    log.info({f"client_{i}": len(ds.indices) for i, ds in enumerate(datasets)})
    return datasets


def new_getattr(self, name):
    """
    Search recursively for attributes under self.dataset.

    Meant to be assigned as `__getattr__` on a wrapper class, so it only runs after normal
    attribute lookup on `self` has failed.
    args:
    name: str, attribute that was not found on `self`
    Returns:
    The attribute from the first object in the `.dataset` chain that has it
    Raises:
    AttributeError if no wrapped dataset has the attribute, for names starting with "__", and
    for `dataset` itself
    """
    # Dunder probes (copy/pickle protocol lookups) must fail fast. A missing `dataset` must fail
    # too: probing it with hasattr() below would re-enter this function and recurse without end.
    if name.startswith("__") or name == "dataset":
        raise AttributeError(name)
    dataset = self
    while hasattr(dataset, 'dataset'):
        dataset = dataset.dataset
        if hasattr(dataset, name):
            return getattr(dataset, name)
    raise AttributeError(name)




def add_attrs(*given_subsets: [Subset]):
    """
    Install `new_getattr` as `__getattr__` on the class of every given subset
    args:
    given_subsets: any number of iterables of subsets
    """
    for subset in itertools.chain.from_iterable(given_subsets):
        # Assigned on the class, so every instance of that class is affected.
        subset.__class__.__getattr__ = new_getattr


def split_dataset_train_val(train_dataset, val_split, seed=49, val_dataset=None):
    """
    Stratified train/val split of a dataset by its `targets`
    args:
    train_dataset: dataset exposing `.targets`
    val_split: forwarded to `train_test_split` as `test_size`
    seed: forwarded to `train_test_split` as `random_state`
    val_dataset: optional dataset the validation subset is taken from instead of `train_dataset`
    Returns:
    (train_subset, val_subset) as Subset objects
    """
    labels = train_dataset.targets
    positions = np.arange(len(labels))
    train_idx, val_idx = train_test_split(
        positions,
        test_size=val_split,
        stratify=labels,
        random_state=seed,
    )
    train_subset = Subset(train_dataset, indices=train_idx)
    # Any falsy val_dataset (not only None) falls back to the training dataset.
    val_source = val_dataset if val_dataset else train_dataset
    return train_subset, Subset(val_source, indices=val_idx)


def no_split(dataset, *args, **kwargs):
    """
    Wrap the whole dataset in a single Subset; extra arguments are accepted and ignored
    Returns:
    List with one Subset covering every index of `dataset`
    """
    every_index = [*np.arange(len(dataset))]
    whole = Subset(dataset, indices=every_index)
    return [whole]


def random_client_data_split(dataset, num_clients, seed):
    """
    Uniformly random partition of a dataset among clients
    args:
    dataset: pytorch dataset object
    num_clients: int
    seed: int for fixing the splits
    Returns:
    List of Dataset subset object of length=num_clients
    """
    total = len(dataset)
    last_client = num_clients - 1

    lengths = []
    for client in range(num_clients):
        if client == last_client:
            # the last client absorbs whatever the integer division left over
            lengths.append(total - (total // num_clients * last_client))
        else:
            lengths.append(total // num_clients)
    assert total == sum(lengths)

    # a dedicated, seeded generator gives the same split on every call
    generator = torch.Generator().manual_seed(seed)
    datasets = random_split(dataset=dataset, lengths=lengths, generator=generator)

    # sanity check: no index may appear in two different clients
    index_sets = [set(ds.indices) for ds in datasets]
    assert all(first.isdisjoint(second) for first, second in itertools.combinations(index_sets, 2))
    return datasets


def stratified_client_data_split(dataset, num_clients, seed):
    """
    Data split with balanced class distrbution amoung clients
    args:
    dataset: pytorch dataset object (or Subset) whose targets support fancy indexing
    num_clients: int
    seed: int for fixing the splits
    Returns:
    List of Dataset subset object of length=num_clients
    """
    # Every client but the last is meant to receive (100 // num_clients) percent of the data.
    # Each client is carved out of what is still left, so client i has to take
    # base / (100 - base * i) of the remainder. Only num_clients - 1 fractions are needed because
    # the last client simply keeps the rest; computing one for it as well used to divide by zero
    # for some client counts (e.g. 6 and 11).
    base_percent = 100 // num_clients if num_clients > 1 else 0
    take_fractions = [base_percent / (100 - base_percent * i) for i in range(num_clients - 1)]

    # if given a subset handle indices carefully
    remaining_idx = dataset.indices if isinstance(dataset, Subset) else np.arange(len(dataset))
    targets = dataset.dataset.targets if isinstance(dataset, Subset) else dataset.targets
    remaining_targets = targets[remaining_idx]

    splits_idx = []
    for fraction in take_fractions:
        client_idx, remaining_idx = train_test_split(
            remaining_idx,
            train_size=fraction, stratify=remaining_targets, random_state=seed
        )
        remaining_targets = targets[remaining_idx]
        splits_idx.append(list(client_idx))

    splits_idx.append(list(remaining_idx))

    datasets = []
    for split in splits_idx:
        if isinstance(dataset, Subset):
            # keep the Subset wrapper (and whatever it carries) and only swap the indices
            subset = deepcopy(dataset)
            subset.indices = split
        else:
            subset = Subset(dataset=dataset, indices=split)
        datasets.append(subset)

    return datasets


def specified_client_data_split(
        dataset, num_clients, seed,
        clients_classes_to_zero, set_close_to_zero,
        dist_type="uniform",
        prec=None
):
    """
    Data split with specifc class distrbution amoung clients
    args:
    dataset: pytorch dataset object (or Subset); its targets must provide `.numpy()`
    num_clients: int
    seed: int for fixing the splits
    clients_classes_to_zero: list of lists to specify which classes indicies to zero out for each client
    set_close_to_zero: bool, option to make the probability very small instated of zero
    dist_type: str, specify each version of the zero out function to use, currently avaliable options are:
        `uniform`, `random` or `controlled`
    prec: forwarded to the zero out function (the `controlled` version takes it as the
        per-client share of the data)
    Returns:
    List of Dataset subset object of length=num_clients
    """
    assert dist_type == "uniform" or dist_type == "random" or dist_type == "controlled"
    if dist_type == "uniform":
        zero_out_func = zero_out_classes_uniform
    elif dist_type == "random":
        zero_out_func = zero_out_classes_random
    else:
        zero_out_func = zero_out_classes_controlled

    # Dedicated generator: makes the sampling below reproducible for a given seed without
    # touching numpy's global RNG state.
    rng = np.random.RandomState(seed)

    indices = np.array(dataset.indices) if isinstance(dataset, Subset) else np.arange(len(dataset))
    targets = dataset.dataset.targets if isinstance(dataset, Subset) else dataset.targets
    labels = targets.numpy()

    # counted over all targets of the underlying dataset, not only the given indices
    num_classes = len(set(labels))
    classes_dist = zero_out_func(
        num_classes=num_classes, num_clients=num_clients,
        clients_classes_to_zero=clients_classes_to_zero, seed=seed,
        set_close_to_zero=set_close_to_zero, prec=prec
    )

    df = pd.DataFrame({"x": indices, "y": labels[indices]})
    splits = [list() for _ in range(num_clients)]

    # Clients are served one after another from a shrinking pool, so each client's share of the
    # class has to be re-expressed as a share of what the previous clients left behind.
    updated_classes_dist = []
    for cls_idx in range(num_classes):
        output = []
        for i in range(num_clients):
            already_taken = sum(classes_dist[j][cls_idx] for j in range(0, i) if classes_dist[j][cls_idx] != 0)
            output.append(classes_dist[i][cls_idx] / (1 - already_taken))
        updated_classes_dist.append(output)

    for cls_idx in range(num_classes):
        for client_idx in range(num_clients):
            class_rows = df.loc[df.y == cls_idx]
            client_portion_size = len(class_rows) * updated_classes_dist[cls_idx][client_idx]
            client_selected_idx = rng.choice(
                a=class_rows['x'].values,
                size=round(client_portion_size),
                replace=False
            )
            splits[client_idx].extend(client_selected_idx)
            df = df[~(df["x"].isin(client_selected_idx))]  # remove already selected indicies

    # every index must be handed out exactly once
    assert len(indices) == len([idx for split in splits for idx in split])
    assert all((set(p0).isdisjoint(set(p1))) for p0, p1 in itertools.combinations(splits, 2))

    datasets = []
    for split in splits:
        if isinstance(dataset, Subset):
            subset = deepcopy(dataset)
            subset.indices = split
        else:
            subset = Subset(dataset=dataset, indices=split)
        datasets.append(subset)

    return datasets


def zero_out_classes_random(num_classes, num_clients, clients_classes_to_zero: list[list], seed,
                            set_close_to_zero=False, **kwargs):
    """
    Build a num_clients x num_classes matrix of class shares in which every class is divided
    randomly among the clients, except at the positions a client asked to have zeroed out
    Args:
    num_classes: int
    num_clients: int
    clients_classes_to_zero: one list of class indices per client; the last one must be empty
    seed: passed to `random.seed`
    set_close_to_zero: bool, use a small random share (0.001 to 0.01) instead of exactly zero
    Returns:
    matrix of size num_clients x num_classes; every column sums to 1
    """
    random.seed(seed)
    assert len(clients_classes_to_zero[-1]) == 0, \
        "last client need to have no zeros so it get all the remaining unpicked data, TODO later!"
    shares = np.zeros((num_clients, num_classes))
    assert len(shares) == num_clients

    last_client = num_clients - 1
    for cls in range(num_classes):
        remaining = 1.0
        for client, zeroed in enumerate(clients_classes_to_zero):
            if zeroed:
                assert min(zeroed) >= 0
                assert max(zeroed) < num_classes

            is_zeroed = cls in zeroed
            if is_zeroed:
                share = random.uniform(0.001, 0.01) if set_close_to_zero else 0
            elif client == last_client:
                # the last client takes whatever nobody picked
                share = remaining
            else:
                share = random.random() * remaining

            shares[client, cls] = share
            remaining -= share
            if client == last_client and not is_zeroed:
                assert remaining == 0

        column_total = np.sum(shares, axis=0)[cls]
        assert np.isclose(column_total, 1), f"sum is {column_total}"
    return shares


def zero_out_classes_uniform(num_classes, num_clients, clients_classes_to_zero: list[list], seed,
                             set_close_to_zero=False, **kwargs):
    """
    Build a num_clients x num_classes matrix of class shares in which every class is divided
    UNIFORMLY among the clients that keep it, with zero share at the zeroed-out positions
    Args:
    num_classes: int
    num_clients: int
    clients_classes_to_zero: one list of class indices per client; the last one must be empty
    seed: passed to `random.seed`
    set_close_to_zero: bool, use a small random share (0.001 to 0.01) instead of exactly zero
    Returns:
    matrix of size num_clients x num_classes; every column sums to 1
    """
    random.seed(seed)
    assert len(clients_classes_to_zero[-1]) == 0, \
        "last client need to have no zeros so it get all the remaining unpicked data, TODO later!"
    shares = np.zeros((num_clients, num_classes))
    assert len(shares) == num_clients
    assert len(clients_classes_to_zero) == num_clients

    # how many clients drop each class
    dropped_by = Counter(cls for zeroed in clients_classes_to_zero for cls in zeroed)
    last_client = num_clients - 1

    for cls in range(num_classes):
        remaining = 1.0
        for client, zeroed in enumerate(clients_classes_to_zero):
            if zeroed:
                assert min(zeroed) >= 0
                assert max(zeroed) < num_classes

            is_zeroed = cls in zeroed
            if is_zeroed:
                share = random.uniform(0.001, 0.01) if set_close_to_zero else 0
            elif client == last_client:
                # the last client takes whatever is left of the class
                share = remaining
            else:
                # equal share among the clients that did not drop this class
                share = 1 / (num_clients - dropped_by[cls])

            shares[client, cls] = share
            remaining -= share
            if client == last_client and not is_zeroed:
                assert remaining == 0

        column_total = np.sum(shares, axis=0)[cls]
        assert np.isclose(column_total, 1), f"sum is {column_total}"
    return shares


def zero_out_classes_controlled(num_classes, num_clients, clients_classes_to_zero: list[list], seed,
                                prec: list, set_close_to_zero=False):
    """
    Given a list that specify which class indices to zero out for each client, this function
    generates a matrix of size num_clients x num_classes in which client i receives the share
    prec[i] of every class it keeps, plus a part of the shares given up by the clients that
    zeroed that class out (distributed proportionally to prec among the clients keeping it)
    Args:
    num_classes: int
    num_clients: int
    clients_classes_to_zero: list of lists to specify which classes indicies to zero out for each client
    seed: passed to `random.seed`; no random numbers are drawn by this function itself
    prec: list with same size as number of clients specifing how much of the data they get
    set_close_to_zero: accepted for signature parity with the other zero out functions, unused
    Returns:
    matrix of size num_clients x num_classes with probability distrbution of classes for each client
    """
    random.seed(seed)
    # Tolerant comparison: e.g. sum([0.1] * 10) is 0.9999999999999999, not exactly 1.
    assert np.isclose(sum(prec), 1), "prec list must sum up to 1!"
    assert len(clients_classes_to_zero[-1]) == 0, \
        "last client need to have no zeros so it get all the remaining unpicked data, TODO later!"
    classes_dist_per_clients = np.zeros((num_clients, num_classes))
    assert len(classes_dist_per_clients) == num_clients
    assert len(clients_classes_to_zero) == num_clients

    classes_missing_prec = np.zeros((num_clients, num_classes))

    counts = Counter([item for sublist in clients_classes_to_zero for item in sublist])

    # Step 1: every row of a class column holds the total share given up for that class ...
    for client_idx, client_zero_out in enumerate(clients_classes_to_zero):
        for cls_idx in client_zero_out:
            classes_missing_prec[:, cls_idx] += prec[client_idx]

    # ... except the rows of the clients that gave it up, which receive nothing back.
    for client_idx, client_zero_out in enumerate(clients_classes_to_zero):
        for cls_idx in client_zero_out:
            classes_missing_prec[client_idx, cls_idx] = 0

    # Step 2: scale the given-up share of each receiving client by its weight among the receivers.
    for client_idx in range(num_clients):
        for cls_idx in counts:
            ignore_idx = np.where(classes_missing_prec[:, cls_idx] == 0)[0]
            receivers_prec = sum([p for i, p in enumerate(prec) if i not in ignore_idx])
            classes_missing_prec[client_idx, cls_idx] *= prec[client_idx] / receivers_prec

    # Step 3: own share plus redistributed share, or zero where the class was zeroed out.
    for cls_idx in range(num_classes):
        for client_idx, client_zero_out in enumerate(clients_classes_to_zero):
            if client_zero_out:
                assert min(client_zero_out) >= 0
                assert max(client_zero_out) < num_classes
            if cls_idx in client_zero_out:
                classes_dist_per_clients[client_idx, cls_idx] = 0
            else:
                classes_dist_per_clients[client_idx, cls_idx] = (
                    prec[client_idx] + classes_missing_prec[client_idx, cls_idx]
                )
        column_total = np.sum(classes_dist_per_clients, axis=0)[cls_idx]
        assert np.isclose(column_total, 1), f"sum is {column_total} for class {cls_idx}"

    return classes_dist_per_clients


def split_subsets_train_val(subsets, val_precent, seed, val_dataset: Dataset = None):
    """
    Cut every client subset into a train part and a validation part
    Args:
        subsets: iterable of Subset objects, one per client
        val_precent: forwarded to `train_test_split` as `test_size`
        seed: forwarded to `train_test_split` as `random_state`
        val_dataset: give if you have a val dataset that have different transforms than the train dataset
    Returns:
        (train_sets, val_sets), two lists aligned with `subsets`
    """
    train_sets, val_sets = [], []
    for subset in subsets:
        train_indices, val_indices = train_test_split(subset.indices, test_size=val_precent, random_state=seed)

        train_part = deepcopy(subset)
        train_part.indices = train_indices

        if not val_dataset:
            val_part = deepcopy(subset)
            val_part.indices = val_indices
        else:
            val_part = Subset(val_dataset, indices=val_indices)

        # trick to get the attrs of the original dataset (patched on the class of each part)
        for part in (train_part, val_part):
            part.__class__.__getattr__ = new_getattr

        train_sets.append(train_part)
        val_sets.append(val_part)

    # ensure that all indices are disjoints and splits are correct
    train_index_sets = [set(s.indices) for s in train_sets]
    val_index_sets = [set(s.indices) for s in val_sets]
    assert all(a.isdisjoint(b) for a, b in itertools.combinations(train_index_sets, 2))
    assert all(a.isdisjoint(b) for a, b in itertools.combinations(val_index_sets, 2))
    assert all(a.isdisjoint(b) for a, b in itertools.combinations(train_index_sets + val_index_sets, 2))
    return train_sets, val_sets







def randomList(m, n):
    """
    Randomly distribute n unit increments over a list of m counters (m clients and n classes)

    The python RNG is re-seeded with 42 on every call, so the result depends only on (m, n).
    Returns:
    shuffled list of m non-negative ints that sum to n
    """
    random.seed(42)
    buckets = [0] * m

    for step in range(n):
        # every 25th step may pick from the wide range 0..n, all others from 0..8
        # (r1: %2, r2: %25)
        upper = n if step % 25 == 0 else 8
        cap = randint(0, upper)
        buckets[randint(0, cap) % m] += 1

    random.shuffle(buckets)
    return buckets









